This script analyzes filtered mAb escape data¶

In [1]:
# this cell is tagged as parameters for `papermill` parameterization
binding_data = None
HENV103_filter = None
HENV117_filter = None
HENV26_filter = None
HENV32_filter = None
m102_filter = None
nAH1_filter = None

altair_config = None
nipah_config = None
escape_bubble_plot = None
bubble_1_mut_plot = None
mab_line_escape_plot = None
aggregate_mab_and_binding = None
aggregate_mab_and_niv_polymorphism = None
binding_vs_escape = None

mab_plot_top = None
mab_plot_all = None
In [2]:
# Parameters
nipah_config = "nipah_config.yaml"
altair_config = "data/custom_analyses_data/theme.py"
HENV103_filter = "results/filtered_data/HENV103_escape_filtered.csv"
HENV117_filter = "results/filtered_data/HENV117_escape_filtered.csv"
HENV26_filter = "results/filtered_data/HENV26_escape_filtered.csv"
HENV32_filter = "results/filtered_data/HENV32_escape_filtered.csv"
m102_filter = "results/filtered_data/m102_escape_filtered.csv"
nAH1_filter = "results/filtered_data/nAH1_escape_filtered.csv"
binding_data = "results/filtered_data/E2_binding_filtered.csv"
escape_bubble_plot = "results/images/escape_bubble_plot.html"
bubble_1_mut_plot = "results/images/escape_bubble_1_mut_plot.html"
overlap_escape_plot = "results/images/overlap_escape_plot.html"
mab_line_escape_plot = "results/images/mab_line_escape_plot.html"
mab_plot_top = "results/images/mab_plot_top.html"
mab_plot_all = "results/images/mab_plot_all.html"
aggregate_mab_and_binding = "results/images/aggregate_mab_and_binding.html"
binding_vs_escape = "results/images/binding_vs_escape.html"
aggregate_mab_and_niv_polymorphism = (
    "results/images/aggregate_mab_and_niv_polymorphism.html"
)
In [3]:
if binding_data is None:
    print("this is being run manually")
else:
    print("papermill!")
papermill!
In [4]:
import math
import os
import re

import altair as alt

import numpy as np

import pandas as pd

import scipy.stats

import Bio.SeqIO
import yaml
import matplotlib

matplotlib.rcParams["svg.fonttype"] = "none"

from Bio import PDB
import dmslogo
from dmslogo.colorschemes import CBPALETTE
from dmslogo.colorschemes import ValueToColorMap
In [5]:
# allow more rows for Altair
_ = alt.data_transformers.disable_max_rows()

if (
    os.getcwd()
    == "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"
):
    pass
    print("Already in correct directory")
else:
    os.chdir(
        "/fh/fast/bloom_j/computational_notebooks/blarsen/2023/Nipah_Malaysia_RBP_DMS/"
    )
    print("Setup in correct directory")
Setup in correct directory

For running interactively¶

In [6]:
if binding_vs_escape is None:
    altair_config = "data/custom_analyses_data/theme.py"
    nipah_config = "nipah_config.yaml"

    binding_data = "results/filtered_data/E2_binding_filtered.csv"

    HENV103_filter = "results/filtered_data/HENV103_escape_filtered.csv"
    HENV117_filter = "results/filtered_data/HENV117_escape_filtered.csv"
    HENV26_filter = "results/filtered_data/HENV26_escape_filtered.csv"
    HENV32_filter = "results/filtered_data/HENV32_escape_filtered.csv"
    m102_filter = "results/filtered_data/m102_escape_filtered.csv"
    nAH1_filter = "results/filtered_data/nAH1_escape_filtered.csv"

    # escape_bubble_plot = 'results/images/escape_bubble_plot.html'
    # bubble_1_mut_plot = 'results/images/escape_bubble_1_mut_plot.html'
    # overlap_escape_plot = 'results/images/overlap_escape_plot.html'

    # m102_heat = 'results/images/m102_heatmap.html'
    # HENV26_heat = 'results/images/HENV26_heatmap.html'
    # HENV32_heat = 'results/images/HENV32_heatmap.html'
    # nAH1_heat = 'results/images/nAH1_heatmap.html'
    # HENV117_heat = 'results/images/HENV117_heatmap.html'
    # HENV103_heat = 'results/images/HENV103_heatmap.html'
In [7]:
if altair_config:
    with open(altair_config, "r") as file:
        exec(file.read())

with open(nipah_config) as f:
    config = yaml.safe_load(f)

Make logo plots¶

Filtering parameters¶

In [8]:
# Make a dataframe with all the mutants with low entry scores for masking later in script
func_scores_E3 = pd.read_csv(
    "../Nipah_Malaysia_RBP_DMS/results/func_effects/averages/CHO_EFNB3_low_func_effects.csv"
)
func_scores_E3_low_effect = func_scores_E3[
    (func_scores_E3["effect"] < config["min_func_effect_for_ab"])
    & (func_scores_E3["times_seen"] > config["func_times_seen_cutoff"])
    & (func_scores_E3["site"] != 603)
    & (func_scores_E3["mutant"] != "-")
    & (func_scores_E3["mutant"] != "*")
]
display(func_scores_E3_low_effect)
site wildtype mutant effect effect_std times_seen n_selections
13 71 Q P -3.430 0.00000 6.714 7
34 72 N P -3.544 0.00000 7.857 7
39 72 N V -2.775 0.07748 7.571 7
55 73 Y P -3.356 0.16900 3.833 6
76 74 T P -3.194 0.12440 4.857 7
... ... ... ... ... ... ... ...
10768 597 I S -3.190 0.40170 3.286 7
10769 597 I T -3.494 0.02079 7.143 7
10772 597 I Y -2.367 0.66970 4.000 7
10774 598 P C -2.125 0.94490 5.143 7
10791 598 P W -3.169 0.00000 8.143 7

2765 rows × 7 columns

Read in filtered antibody escape files and combine.¶

In [9]:
HENV103 = pd.read_csv(HENV103_filter)
HENV117 = pd.read_csv(HENV117_filter)
HENV26 = pd.read_csv(HENV26_filter)
HENV32 = pd.read_csv(HENV32_filter)
m102 = pd.read_csv(m102_filter)
nAH1 = pd.read_csv(nAH1_filter)

# Combine all the individual filtered antibody escape files
combined_df = pd.concat([HENV103, HENV117, HENV26, HENV32, m102, nAH1])
combined_df = combined_df[
    [
        "site",
        "wildtype",
        "mutant",
        "mutation",
        "effect",
        "escape_median",
        "escape_std",
        "times_seen_ab",
        "show_site",
        "ab",
    ]
]
display(combined_df)

# Make a separate dataframe that only has the top sites
filtered_df = combined_df.query("show_site == True")
filtered_df = filtered_df[filtered_df["escape_median"] >= config["min_escape_cutoff"]]
display(filtered_df)
site wildtype mutant mutation effect escape_median escape_std times_seen_ab show_site ab
0 71 Q D Q71D -0.4981 -0.003238 0.05961 3.333 False HENV-103
1 71 Q E Q71E 0.3605 -0.029820 0.12310 3.333 False HENV-103
2 71 Q F Q71F -0.5317 0.012440 0.01263 2.333 False HENV-103
3 71 Q G Q71G -1.2470 0.040310 0.02698 3.333 False HENV-103
4 71 Q H Q71H -0.2558 0.014980 0.23130 2.667 False HENV-103
... ... ... ... ... ... ... ... ... ... ...
6913 602 T R T602R 0.4772 -0.016640 0.00485 7.333 False nAH1.3
6914 602 T S T602S 0.3906 0.161300 0.23030 3.667 False nAH1.3
6915 602 T V T602V 0.3780 0.131000 0.07729 6.000 False nAH1.3
6916 602 T W T602W 0.5438 0.071860 0.09132 6.333 False nAH1.3
6917 602 T Y T602Y 0.4982 0.246200 0.09855 6.667 False nAH1.3

41400 rows × 10 columns

site wildtype mutant mutation effect escape_median escape_std times_seen_ab show_site ab
906 154 K I K154I -0.93860 0.5414 0.33560 5.333 True HENV-103
912 154 K S K154S 0.23260 0.4880 0.31120 2.667 True HENV-103
914 154 K V K154V -0.77220 0.4362 0.32850 5.333 True HENV-103
1180 176 E F E176F 0.24160 0.5403 0.34270 5.000 True HENV-103
1185 176 E L E176L -0.09272 0.4601 0.27780 5.000 True HENV-103
... ... ... ... ... ... ... ... ... ... ...
5930 518 N M N518M 0.34070 0.9404 0.72310 5.667 True nAH1.3
5931 518 N P N518P -0.56450 0.9933 0.09195 5.333 True nAH1.3
5933 518 N R N518R 0.44010 0.7831 0.24450 7.667 True nAH1.3
5937 518 N W N518W 0.16300 0.8366 0.25700 6.000 True nAH1.3
5938 518 N Y N518Y -0.40690 1.5040 0.49130 5.333 True nAH1.3

517 rows × 10 columns

In [10]:
def identify_escape_sites(df, ab):
    subset = df[(df["ab"] == ab)]
    unique_sites = list(subset["site"].unique())
    return unique_sites


abs = ["HENV-26", "HENV-103", "HENV-32", "HENV-117", "m102.4", "nAH1.3"]
sites_dict = {}  # Create an empty dictionary to store the results

for ab in abs:
    sites_dict[ab] = identify_escape_sites(filtered_df, ab)

display(sites_dict)  # need site dict for later
{'HENV-26': [166, 171, 204, 257, 490, 491, 492, 494, 497, 501, 529, 530],
 'HENV-103': [154, 176, 205, 258, 259, 260, 264, 268, 274, 275, 277],
 'HENV-32': [154, 176, 199, 200, 205, 207, 268, 274, 275, 277, 534, 596],
 'HENV-117': [171, 204, 257, 555, 580, 582, 583, 586, 587, 588, 589],
 'm102.4': [172, 239, 270, 507, 542, 559, 577, 582, 586, 587, 589],
 'nAH1.3': [184, 185, 188, 189, 190, 447, 450, 468, 516, 517, 518]}

Plot bubble chart showing mAb escape for individual mutants by functional score for both E2 or E3¶

In [11]:
order_ab = ["m102.4", "HENV-26", "HENV-117", "HENV-103", "HENV-32", "nAH1.3"]


def generate_chart(df):
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=1
    )
    chart = (
        alt.Chart(
            df,
            title=alt.Title(
                "Top Antibody Escape Mutations",
                subtitle="Hover over points to see escape at same site",
            ),
        )
        .mark_point(stroke="black")
        .encode(
            x=alt.X(
                "ab:O",
                sort=order_ab,
                title="Antibody",
                axis=alt.Axis(labelAngle=-90, grid=False),
            ),
            y=alt.Y(
                "effect:Q",
                title="Cell Entry of Top Escape",
                axis=alt.Axis(
                    grid=True, tickCount=4, values=[0.5, 0, -0.5, -1, -1.5, -2]
                ),
            ),
            size=alt.Size(
                "escape_median", legend=alt.Legend(title="Mean Escape By Mutation")
            ),
            xOffset="random:Q",
            tooltip=[
                "site",
                "wildtype",
                "mutant",
                "ab",
                "effect",
                "escape_median",
                "escape_std",
            ],
            color=alt.Color("ab").legend(None),
            opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.2)),
            strokeWidth=alt.condition(variant_selector, alt.value(2), alt.value(0)),
        )
        .transform_calculate(
            random="sqrt(-1*log(random()))*cos(2*PI*random())"
        )
        .properties(width=config["bubble_width"], height=config["bubble_height"])
        .add_params(variant_selector)
    )

    return chart


escape_bubble = generate_chart(filtered_df)
escape_bubble.display()
if mab_line_escape_plot is not None:
    escape_bubble.save(escape_bubble_plot)

Now summarize by number of mutations between wildtype and mutant codons¶

In [12]:
# Load in wt nucleotide sequence (which is different than the 'wt' sequence from Library as it was codon optimized)
niv_m_wt = str(
    Bio.SeqIO.read(
        "data/custom_analyses_data/alignments/wild_type_seq.fasta", "fasta"
    ).seq
)

codon_table = {
    "ATA": "I",
    "ATC": "I",
    "ATT": "I",
    "ATG": "M",
    "ACA": "T",
    "ACC": "T",
    "ACG": "T",
    "ACT": "T",
    "AAC": "N",
    "AAT": "N",
    "AAA": "K",
    "AAG": "K",
    "AGC": "S",
    "AGT": "S",
    "AGA": "R",
    "AGG": "R",
    "CTA": "L",
    "CTC": "L",
    "CTG": "L",
    "CTT": "L",
    "CCA": "P",
    "CCC": "P",
    "CCG": "P",
    "CCT": "P",
    "CAC": "H",
    "CAT": "H",
    "CAA": "Q",
    "CAG": "Q",
    "CGA": "R",
    "CGC": "R",
    "CGG": "R",
    "CGT": "R",
    "GTA": "V",
    "GTC": "V",
    "GTG": "V",
    "GTT": "V",
    "GCA": "A",
    "GCC": "A",
    "GCG": "A",
    "GCT": "A",
    "GAC": "D",
    "GAT": "D",
    "GAA": "E",
    "GAG": "E",
    "GGA": "G",
    "GGC": "G",
    "GGG": "G",
    "GGT": "G",
    "TCA": "S",
    "TCC": "S",
    "TCG": "S",
    "TCT": "S",
    "TTC": "F",
    "TTT": "F",
    "TTA": "L",
    "TTG": "L",
    "TAC": "Y",
    "TAT": "Y",
    "TAA": "*",
    "TAG": "*",
    "TGC": "C",
    "TGT": "C",
    "TGA": "*",
    "TGG": "W",
}


def find_closest_codon(wt_codon, mutant_aa):
    mutant_codons = [codon for codon, aa in codon_table.items() if aa == mutant_aa]
    min_mutations = 3  # Maximum mutations possible
    closest_codon = None
    for m_codon in mutant_codons:
        mutations = sum([1 for c1, c2 in zip(wt_codon, m_codon) if c1 != c2])
        if mutations < min_mutations:
            min_mutations = mutations
            closest_codon = m_codon
    return closest_codon, min_mutations


# Function to extract codon for a given site
def extract_codon(site):
    idx = (site - 1) * 3
    return niv_m_wt[idx : idx + 3]


def extract_codon_niv_b(site):
    idx = (site - 1) * 3
    return niv_m_wt[idx : idx + 3]


def apply_codon_to_df(df, extract_func):
    df["wt_codon"] = df["site"].apply(extract_func)
    df["closest_mutant_codon"] = df.apply(
        lambda row: find_closest_codon(row["wt_codon"], row["mutant"])[0], axis=1
    )
    df["min_mutations"] = df.apply(
        lambda row: find_closest_codon(row["wt_codon"], row["mutant"])[1], axis=1
    )
    return df


combined_df = apply_codon_to_df(combined_df, extract_codon)
filtered_df = apply_codon_to_df(filtered_df, extract_codon)
In [13]:
def generate_chart_all(df):
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=1
    )
    radio = alt.binding_radio(
        options=[1, 2, 3], labels=["1", "2", "3"], name="Min Mutations:"
    )
    mutation_selector = alt.param(name="MutationSelector", value=1, bind=radio)

    slider = alt.binding_range(min=0.2, max=1.6, step=0.1, name="median_escape")
    selector = alt.param(name="SelectorName", value=0.2, bind=slider)

    chart = (
        alt.Chart(
            df,
            title=alt.Title(
                "Antibody Escape Mutations",
                subtitle="Hover over points to see escape at same site",
            ),
        )
        .mark_point(filled=True, stroke="black")
        .encode(
            x=alt.X(
                "ab:O",
                sort=order_ab,
                title="Antibody",
                axis=alt.Axis(labelAngle=-90, grid=False),
            ),
            y=alt.Y(
                "effect:Q",
                title="Cell Entry of Top Escape",
                axis=alt.Axis(
                    grid=True, tickCount=4, values=[0.5, 0, -0.5, -1, -1.5, -2]
                ),
            ),
            size=alt.Size(
                "escape_median", legend=alt.Legend(title="Mean Escape By Mutation")
            ),
            xOffset="random:Q",
            tooltip=[
                "site",
                "wildtype",
                "mutant",
                "ab",
                "effect",
                "escape_median",
                "escape_std",
            ],
            color=alt.Color("ab").legend(None),
            opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.4)),
            strokeWidth=alt.condition(variant_selector, alt.value(2), alt.value(0)),
        )
        .transform_calculate(
            random="sqrt(-1*log(random()))*cos(2*PI*random())"
            # random='random'
        )
        #.properties(width=config["bubble_width"], height=config["bubble_height"])
        .properties(width=200,height=250)
        .add_params(variant_selector, mutation_selector, selector)
        .transform_filter(
            (alt.datum.min_mutations == mutation_selector)
            & (alt.datum.escape_median > selector)
        )
    )

    return chart


#all_escape = generate_chart_all(combined_df.query("escape_median >= 0.2"))
all_escape = generate_chart_all(filtered_df)

all_escape.display()
In [14]:
# Make combined figure
combined_bubble_plots = (escape_bubble | all_escape)
combined_bubble_plots.display()
In [15]:
def plot_escape_and_mutations_away(df):
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=1
    )
    radio = alt.binding_radio(options=[1, 2, 3], name="Min Mutations:")
    mutation_selector = alt.param(name="MutationSelector", value=1, bind=radio)

    chart = (
        alt.Chart(
            df,
            title=alt.Title(
                "Top Antibody Escape Mutations",
                subtitle="By # of nucleotide mutations away",
            ),
        )
        .mark_point(filled=True, stroke="black",strokeWidth=0.75,opacity=0.3)
        .encode(
            x=alt.X(
                "ab:O",
                sort=order_ab,
                title='Antibody',
                axis=alt.Axis(labelAngle=-90, grid=False),
            ),
            y=alt.Y(
                "effect:Q",
                title="Cell Entry of Escape Mutants",
                axis=alt.Axis(
                    grid=True, tickCount=4, values=[0.5, 0, -0.5, -1, -1.5, -2]
                ),
            ),  # 'Q' denotes a quantitative variable
            size=alt.Size("escape_median", legend=alt.Legend(title="Escape of Mutant")),
            xOffset="random:Q",
            tooltip=["ab", "effect", "escape_median", "site", "mutant"],
            #opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.4)),
            #strokeWidth=alt.condition(variant_selector, alt.value(2), alt.value(0)),
            color=alt.Color("ab:N").legend(None),
        )
        .transform_calculate(
            # random='random()'
            random="sqrt(-2*log(random()))*cos(2*PI*random())"
        )
        .properties(width=200, height=config["bubble_height"])
        .add_params(variant_selector, mutation_selector)
        .transform_filter((alt.datum.min_mutations == mutation_selector))
    )
    return chart


bubble_plot_1_mut_away = plot_escape_and_mutations_away(filtered_df)
bubble_plot_1_mut_away.display()
if mab_line_escape_plot is not None:
    bubble_plot_1_mut_away.save(bubble_1_mut_plot)
In [16]:
def plot_escape_and_mutations_away_summed(df):
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=1
    )
    aggregated_df = df.groupby(['site','ab']).agg({
        'escape_median': 'sum',
        'wildtype': 'first',
        #'ab': 'first',
        'effect': 'median',
    }).reset_index()
    aggregated_df = aggregated_df.query('escape_median > 2')
    #display(aggregated_df)
    
    chart = (
        alt.Chart(
            aggregated_df,
        )
        .mark_point(filled=True, stroke="black",strokeWidth=0.75,opacity=0.3)
        .encode(
            x=alt.X(
                "ab:O",
                sort=order_ab,
                title='Antibody',
                axis=alt.Axis(labelAngle=-90, grid=True),
            ),
            y=alt.Y(
                "effect:Q",
                title="Cell Entry of Escape Mutants",
                axis=alt.Axis(
                    grid=True, tickCount=4, #values=[0.5, 0, -0.5, -1, -1.5, -2]
                ),
            ),  
            size=alt.Size("escape_median", legend=alt.Legend(title="Escape Score")),
            xOffset="random:Q",
            tooltip=["ab", "effect", "escape_median", "site"],
            #opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.4)),
            #strokeWidth=alt.condition(variant_selector, alt.value(2), alt.value(0)),
            color=alt.Color("ab:N").legend(None),
        )
        .transform_calculate(
            # random='random()'
            random="sqrt(-2*log(random()))*cos(2*PI*random())"
        )
        .properties(width=150, height=config["bubble_height"])
        .add_params(variant_selector)
    )
    return chart


bubble_plot_1_mut_away = plot_escape_and_mutations_away_summed(combined_df)
bubble_plot_1_mut_away.display()
#if mab_line_escape_plot is not None:
#    bubble_plot_1_mut_away.save(bubble_1_mut_plot)
In [17]:
def find_overlapping_escape(df):
    slider = alt.binding_range(
        min=config["min_func_effect_for_ab"], max=0, step=0.25, name="effect"
    )
    selector = alt.param(name="SelectorName", value=-4, bind=slider)

    radio = alt.binding_radio(options=[1, 2, 3], name="Min Mutations:")
    mutation_selector = alt.param(name="MutationSelector", value=1, bind=radio)

    df_filtered = df
    # Group by 'site' and 'mutant', count the unique 'ab' values for each group
    grouped = df_filtered.groupby(["site", "mutant"])["ab"].nunique().reset_index()

    # Filter groups where the count of unique 'ab' values is at least 2
    result = grouped[grouped["ab"] >= 2]

    # Merge the result with the original dataframe to get the full rows
    df_result = pd.merge(df, result[["site", "mutant"]], on=["site", "mutant"])
    df_result["mutation_number"] = (
        df_result["mutation"].str.extract("(\d+)").astype(int)
    )
    base = (
        (
            alt.Chart(df_result, title=alt.Title("Shared antibody escape mutations"))
            .mark_rect()
            .encode(
                x=alt.X(
                    "mutation:O",
                    title="Site",
                    sort=alt.EncodingSortField(field="mutation_number"),
                    axis=alt.Axis(labelAngle=-90, grid=False),
                ),
                y=alt.Y(
                    "ab:O", title="Mutant", sort=order_ab, axis=alt.Axis(grid=False)
                ),  # Apply custom sort order here
                color="escape_median",
                tooltip=[
                    "site",
                    "wildtype",
                    "mutant",
                    "escape_median",
                    "min_mutations",
                ],
            )
        )
        .properties(height=200, width=400)
        .add_params(selector, mutation_selector)
        .transform_filter(
            (alt.datum.effect >= selector)
            & (alt.datum.min_mutations == mutation_selector)
        )
    )
    return base


overlap_escape = find_overlapping_escape(combined_df.query('escape_median > 0.4'))
overlap_escape.display()
if mab_line_escape_plot is not None:
    overlap_escape.save(overlap_escape_plot)

Line plots of escape¶

In [18]:
display(combined_df[combined_df['site'] == 501])
site wildtype mutant mutation effect escape_median escape_std times_seen_ab show_site ab wt_codon closest_mutant_codon min_mutations
5740 501 E A E501A -1.2510 -0.004303 0.06015 5.667 False HENV-103 GAG GCG 1
5741 501 E F E501F -0.7289 0.059430 0.04503 4.000 False HENV-103 GAG None 3
5742 501 E H E501H -1.2240 0.145000 0.20050 6.333 False HENV-103 GAG CAC 2
5743 501 E M E501M -1.8880 0.064130 0.07890 5.000 False HENV-103 GAG ATG 2
5744 501 E N E501N -0.5469 0.110300 0.11350 6.000 False HENV-103 GAG AAC 2
5745 501 E Q E501Q -0.1457 0.189500 0.11340 6.000 False HENV-103 GAG CAG 1
5746 501 E R E501R -1.8150 -0.003218 0.12690 4.333 False HENV-103 GAG AGG 2
5747 501 E S E501S -1.3720 -0.008346 0.03921 6.667 False HENV-103 GAG TCG 2
5748 501 E W E501W -1.5340 -0.010080 0.11260 6.333 False HENV-103 GAG TGG 2
5740 501 E A E501A -1.2510 -0.260900 0.08308 5.667 False HENV-117 GAG GCG 1
5741 501 E F E501F -0.7289 -0.156200 0.05349 4.000 False HENV-117 GAG None 3
5742 501 E H E501H -1.2240 -0.114300 0.04911 6.333 False HENV-117 GAG CAC 2
5743 501 E M E501M -1.8880 -0.227900 0.07708 5.000 False HENV-117 GAG ATG 2
5744 501 E N E501N -0.5469 -0.184700 0.16850 6.000 False HENV-117 GAG AAC 2
5745 501 E Q E501Q -0.1457 -0.075460 0.11920 6.000 False HENV-117 GAG CAG 1
5746 501 E R E501R -1.8150 -0.049160 0.04527 4.333 False HENV-117 GAG AGG 2
5747 501 E S E501S -1.3720 -0.230500 0.08288 6.667 False HENV-117 GAG TCG 2
5748 501 E W E501W -1.5340 -0.121700 0.22670 6.333 False HENV-117 GAG TGG 2
5695 501 E A E501A -1.2510 0.424800 0.26380 5.667 True HENV-26 GAG GCG 1
5696 501 E F E501F -0.7289 0.154400 0.02275 4.000 True HENV-26 GAG None 3
5697 501 E H E501H -1.2240 0.358000 0.04759 6.667 True HENV-26 GAG CAC 2
5698 501 E M E501M -1.8880 0.832500 0.53770 4.333 True HENV-26 GAG ATG 2
5699 501 E N E501N -0.5469 0.548800 0.20380 6.000 True HENV-26 GAG AAC 2
5700 501 E Q E501Q -0.1457 0.584000 0.19860 6.000 True HENV-26 GAG CAG 1
5701 501 E R E501R -1.8150 1.513000 1.63300 4.444 True HENV-26 GAG AGG 2
5702 501 E S E501S -1.3720 0.405700 0.27800 5.667 True HENV-26 GAG TCG 2
5703 501 E W E501W -1.5340 0.467900 0.07725 6.333 True HENV-26 GAG TGG 2
5759 501 E A E501A -1.2510 0.095320 0.15570 5.250 False HENV-32 GAG GCG 1
5760 501 E F E501F -0.7289 0.067580 0.05454 3.750 False HENV-32 GAG None 3
5761 501 E H E501H -1.2240 0.153900 0.19730 7.250 False HENV-32 GAG CAC 2
5762 501 E M E501M -1.8880 -0.039990 0.21370 4.000 False HENV-32 GAG ATG 2
5763 501 E N E501N -0.5469 0.205100 0.05605 6.250 False HENV-32 GAG AAC 2
5764 501 E Q E501Q -0.1457 0.161900 0.08877 5.750 False HENV-32 GAG CAG 1
5765 501 E R E501R -1.8150 0.148700 0.11470 4.500 False HENV-32 GAG AGG 2
5766 501 E S E501S -1.3720 0.203100 0.25200 6.250 False HENV-32 GAG TCG 2
5767 501 E W E501W -1.5340 0.105500 0.13840 5.750 False HENV-32 GAG TGG 2
5778 501 E A E501A -1.2510 -0.356500 0.12440 6.500 False m102.4 GAG GCG 1
5779 501 E F E501F -0.7289 -0.255600 0.07989 4.500 False m102.4 GAG None 3
5780 501 E H E501H -1.2240 -0.460300 0.24370 7.500 False m102.4 GAG CAC 2
5781 501 E M E501M -1.8880 -0.341600 0.23340 4.500 False m102.4 GAG ATG 2
5782 501 E N E501N -0.5469 -0.146500 0.36710 5.000 False m102.4 GAG AAC 2
5783 501 E Q E501Q -0.1457 -0.307600 0.09046 6.500 False m102.4 GAG CAG 1
5784 501 E R E501R -1.8150 -0.385900 0.38200 5.500 False m102.4 GAG AGG 2
5785 501 E S E501S -1.3720 -0.364500 0.14610 8.000 False m102.4 GAG TCG 2
5786 501 E W E501W -1.5340 -0.589100 0.24070 6.500 False m102.4 GAG TGG 2
5761 501 E A E501A -1.2510 -0.358200 0.27840 5.667 False nAH1.3 GAG GCG 1
5762 501 E F E501F -0.7289 -0.290100 0.10920 4.000 False nAH1.3 GAG None 3
5763 501 E H E501H -1.2240 -0.280000 0.10070 7.000 False nAH1.3 GAG CAC 2
5764 501 E M E501M -1.8880 -0.214700 0.09829 4.000 False nAH1.3 GAG ATG 2
5765 501 E N E501N -0.5469 -0.238300 0.10330 5.667 False nAH1.3 GAG AAC 2
5766 501 E Q E501Q -0.1457 -0.115800 0.15010 6.000 False nAH1.3 GAG CAG 1
5767 501 E R E501R -1.8150 -0.197000 0.39550 5.000 False nAH1.3 GAG AGG 2
5768 501 E S E501S -1.3720 -0.288100 0.09696 7.000 False nAH1.3 GAG TCG 2
5769 501 E W E501W -1.5340 -0.490200 0.15000 6.000 False nAH1.3 GAG TGG 2
In [19]:
def plot_line_escape(df):
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, fields=["site"], value=0
    )
    # Group by 'site' and 'mutant', count the unique 'ab' values for each group
    summed = df.groupby(["site", "ab"])["escape_median"].sum().reset_index()
    
    # Need to add dummy row because site 500 has been completely masked out due to low entry scores and not showing up on x-axis
    new_row = pd.DataFrame({'site': [500], 'ab': ['HENV-103'], 'escape_median': [0]})
    summed = pd.concat([summed, new_row], ignore_index=True)
    
    empty_chart = []
    ab_list = ["m102.4", "HENV-26", "HENV-117", "HENV-103", "HENV-32", "nAH1.3"]
    for idx, ab in enumerate(ab_list):
        tmp_df = summed[summed["ab"] == ab]
        # color = '#1f4e79'
        if ab in ["m102.4", "HENV-26", "HENV-117"]:
            color = "#1f4e79"
        if ab in ["HENV-103", "HENV-32"]:
            color = "#ff7f0e"
        if ab in ["nAH1.3"]:
            color = "#2ca02c"

        # Conditionally set the x-axis labels and title for the last plot
        is_last_plot = idx == len(ab_list) - 1
        x_axis = alt.Axis(
            values=[100, 200, 300, 400, 500, 600],
            tickCount=6,
            labelAngle=-90,
            grid=True,
            labelExpr="datum.value % 100 === 0 ? datum.value : ''",
            title="Site" if is_last_plot else None,
            labels=is_last_plot,
        )  # Only show labels for the last plot
        base = (
            alt.Chart(tmp_df)
            .mark_line(size=1, color=color)
            .encode(
                x=alt.X("site:O", axis=x_axis),
                y=alt.Y("escape_median", title=f"{ab}", axis=alt.Axis(grid=True)),
            )
            .properties(
                width=config["large_line_width"], height=config["large_line_height"]
            )
        )
        point = (
            base.mark_point(color="black", size=10, filled=True)
            .encode(
                x=alt.X("site:O", axis=x_axis),
                y=alt.Y(
                    "escape_median",
                    title=f"{ab}",
                    axis=alt.Axis(
                        grid=True,
                    ),
                ),
                size=alt.condition(variant_selector, alt.value(100), alt.value(15)),
                color=alt.condition(
                    variant_selector, alt.value("black"), alt.value(color)
                ),
                tooltip=["site", "escape_median"],
            )
            .add_params(variant_selector)
        )
        chart = base + point
        empty_chart.append(chart)

    # Use configure_concat to adjust spacing between vertically concatenated plots
    combined_chart = (
        alt.vconcat(*empty_chart, spacing=1)
        .resolve_scale(y="independent", x="shared", color="independent")
        .properties(
            title=alt.Title(
                "Summed Antibody Escape by Site", subtitle="Colored by epitope"
            )
        )
    )

    return combined_chart


tmp_line = plot_line_escape(combined_df)
tmp_line.display()
if mab_line_escape_plot is not None:
    tmp_line.save(mab_line_escape_plot)

Now calculate atomic distances between escape sites and closest amino acid in heavy and light chains¶

In [20]:
def calculate_min_distances(pdb_path, source_chain_id, target_chain_ids, name):
    # Initialize the PDB parser and load the structure
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure("structure_id", pdb_path)

    source_chain = structure[0][source_chain_id]
    target_chains = [structure[0][chain_id] for chain_id in target_chain_ids]

    data = []

    for residueA in source_chain:
        if residueA.resname in ["HOH", "WAT", "IPA", "NAG"]:
            continue

        min_distance = float("inf")
        closest_residueB = None
        closest_chain_id = None
        residues_within_4 = 0

        for target_chain in target_chains:
            for residueB in target_chain:
                if residueB.resname in ["HOH", "WAT", "IPA"]:
                    continue

                # Check for residues within 4 angstroms
                is_within_4 = False
                for atomA in residueA:
                    for atomB in residueB:
                        distance = atomA - atomB
                        if distance < min_distance:
                            min_distance = distance
                            closest_residueB = residueB
                            closest_chain_id = target_chain.get_id()
                        if distance < 4:
                            is_within_4 = True
                if is_within_4:
                    residues_within_4 += 1

        data.append(
            {
                "wildtype": residueA.resname,
                "site": residueA.id[1],
                "chain": closest_chain_id,
                "residue": closest_residueB.id[1],
                "residue_name": closest_residueB.resname,
                "distance": min_distance,
                "residues_within_4": residues_within_4,
                "ab": name,
            }
        )

    # Convert data to pandas DataFrame
    df = pd.DataFrame(data)
    return df


def check_file(input_path, source_chain, target_chain, name, output_path):

    file_path = output_path

    if not os.path.exists(file_path):
        print(f"File {name} does not exist, running calculation")
        output_df = calculate_min_distances(
            input_path, source_chain, target_chain, name
        )
        print(f"done calculating for {file_path}")
        output_df.to_csv(output_path, index=False)
        return output_df
    else:
        print("File already exists,loading from disk")
        output_df = pd.read_csv(output_path)
        return output_df


pdb_path_26 = "data/custom_analyses_data/crystal_structures/6vy5.pdb"
source_chain_26 = "A"
target_chains_26 = ["H", "L"]
output_path_26 = "results/distances/df_HENV26_atomic_distances.csv"

pdb_path_32 = "data/custom_analyses_data/crystal_structures/6vy4.pdb"
source_chain_32 = "A"
target_chains_32 = ["H", "L"]
output_path_32 = "results/distances/df_HENV32_atomic_distances.csv"

pdb_path_nah = "data/custom_analyses_data/crystal_structures/7txz.pdb"
source_chain_nah = "A"
target_chains_nah = ["F", "E"]
output_path_nah = "results/distances/df_nAH_atomic_distances.csv"

pdb_path_m102 = "data/custom_analyses_data/crystal_structures/6cmg.pdb"
source_chain_m102 = "A"
target_chains_m102 = ["B", "C"]
output_path_m102 = "results/distances/df_m102_atomic_distances.csv"


df_HENV26 = check_file(
    pdb_path_26, source_chain_26, target_chains_26, "HENV-26", output_path_26
)
df_HENV32 = check_file(
    pdb_path_32, source_chain_32, target_chains_32, "HENV-32", output_path_32
)
df_nah = check_file(
    pdb_path_nah, source_chain_nah, target_chains_nah, "nAH1.3", output_path_nah
)
df_nah["chain"].replace(
    {"E": "H", "F": "L"}, inplace=True
)  # Fix naming so consistent heavy and light chain naming
df_m102 = check_file(
    pdb_path_m102, source_chain_m102, target_chains_m102, "m102.4", output_path_m102
)
df_m102["chain"].replace(
    {"C": "H", "B": "L"}, inplace=True
)  # Fix naming so consistent heavy and light chain naming
File already exists,loading from disk
File already exists,loading from disk
File already exists,loading from disk
File already exists,loading from disk
In [21]:
def find_close_mab_sites(df, name):
    unique_sites = df.query("distance <= 4")["site"].unique()
    mab_site_list = list(unique_sites)
    print(f"Close sites for mAb {name} are: {mab_site_list}")
    return mab_site_list


### First find RBP sites that are close to mAb residues
nah_close = find_close_mab_sites(df_nah, "nAH1.3")
HENV26_close = find_close_mab_sites(df_HENV26, "HENV-26")
HENV32_close = find_close_mab_sites(df_HENV32, "HENV-32")
m102_close = find_close_mab_sites(df_m102, "m102.4")

### Now combined the close residues AND the top escape sites identified previously
nah_combined_sites = sites_dict["nAH1.3"] + nah_close
HENV26_combined_sites = sites_dict["HENV-26"] + HENV26_close
HENV32_combined_sites = sites_dict["HENV-32"] + HENV32_close
m102_combined_sites = sites_dict["m102.4"] + m102_close
Close sites for mAb nAH1.3 are: [172, 183, 184, 185, 186, 187, 188, 190, 191, 358, 449, 450, 451, 472, 515, 516, 517, 518, 570]
Close sites for mAb HENV-26 are: [389, 401, 403, 404, 458, 488, 489, 490, 491, 492, 494, 497, 501, 504, 505, 506, 528, 529, 530, 531, 532, 533, 555, 556, 557, 581, 586]
Close sites for mAb HENV-32 are: [196, 199, 200, 201, 202, 203, 205, 206, 207, 210, 254, 256, 258, 260, 262, 263, 264, 266, 553]
Close sites for mAb m102.4 are: [239, 240, 241, 242, 305, 458, 488, 489, 490, 504, 505, 506, 507, 530, 532, 533, 555, 557, 559, 579, 580, 581, 588]
In [22]:
def make_distance(df):
    subset_df = df[["site", "distance"]].copy()
    subset_df["mutant"] = "distance"
    subset_df["wildtype"] = ""
    subset_df["effect"] = "escape_median"
    subset_df.rename(columns={"distance": "value"}, inplace=True)
    return subset_df


distance_nah_df = make_distance(df_nah)
distance_26_df = make_distance(df_HENV26)
distance_32_df = make_distance(df_HENV32)
distance_m102_df = make_distance(df_m102)

display(distance_m102_df)
site value mutant wildtype effect
0 176 35.044434 distance escape_median
1 177 31.866014 distance escape_median
2 178 27.842815 distance escape_median
3 179 28.777035 distance escape_median
4 180 28.332012 distance escape_median
... ... ... ... ... ...
423 599 30.802711 distance escape_median
424 600 28.920950 distance escape_median
425 601 27.248772 distance escape_median
426 602 29.020868 distance escape_median
427 603 27.945621 distance escape_median

428 rows × 5 columns

Prepare dataframes for heatmaps¶

In [23]:
def make_empty_df_with_distance(ab, distance_file):
    # print(ab)
    sites = range(71, 603)
    amino_acids = [
        "A",
        "C",
        "D",
        "E",
        "F",
        "G",
        "H",
        "I",
        "K",
        "L",
        "M",
        "N",
        "P",
        "Q",
        "R",
        "S",
        "T",
        "V",
        "W",
        "Y",
    ]
    # Create the combination of each site with each amino acid
    data = [{"site": site, "mutant": aa} for site in sites for aa in amino_acids]
    # Create the DataFrame
    empty_df = pd.DataFrame(data)
    all_sites_df = pd.merge(
        empty_df, combined_df.query(f'ab == "{ab}"'), on=["site", "mutant"], how="left"
    )
    df_melted = all_sites_df.melt(
        id_vars=["site", "mutant", "wildtype"],
        value_vars=["escape_median"],
        var_name="effect",
        value_name="value",
    )

    df_filtered = func_scores_E3_low_effect.melt(
        id_vars=["site", "mutant", "wildtype"],
        value_vars=["effect"],
        var_name="effect",
        value_name="value",
    )

    df_test = pd.concat([df_melted, df_filtered, distance_file], ignore_index=True)
    df_test["ab"] = ab
    return df_test


empty_df_m102 = make_empty_df_with_distance("m102.4", distance_m102_df)
empty_df_HENV26 = make_empty_df_with_distance("HENV-26", distance_26_df)
empty_df_HENV32 = make_empty_df_with_distance("HENV-32", distance_32_df)
empty_df_nah = make_empty_df_with_distance("nAH1.3", distance_nah_df)


def make_empty_df(ab):
    sites = range(71, 603)
    amino_acids = [
        "A",
        "C",
        "D",
        "E",
        "F",
        "G",
        "H",
        "I",
        "K",
        "L",
        "M",
        "N",
        "P",
        "Q",
        "R",
        "S",
        "T",
        "V",
        "W",
        "Y",
    ]
    # Create the combination of each site with each amino acid
    data = [{"site": site, "mutant": aa} for site in sites for aa in amino_acids]
    # Create the DataFrame
    empty_df = pd.DataFrame(data)
    all_sites_df = pd.merge(
        empty_df, combined_df.query(f'ab == "{ab}"'), on=["site", "mutant"], how="left"
    )
    df_melted = all_sites_df.melt(
        id_vars=["site", "mutant", "wildtype"],
        value_vars=["escape_median"],
        var_name="effect",
        value_name="value",
    )

    df_filtered = func_scores_E3_low_effect.melt(
        id_vars=["site", "mutant", "wildtype"],
        value_vars=["effect"],
        var_name="effect",
        value_name="value",
    )

    df_test = pd.concat([df_melted, df_filtered], ignore_index=True)
    df_test["ab"] = ab
    return df_test


empty_df_HENV117 = make_empty_df("HENV-117")
empty_df_HENV103 = make_empty_df("HENV-103")

combined_ab = pd.concat(
    [
        empty_df_m102,
        empty_df_HENV26,
        empty_df_HENV32,
        empty_df_nah,
        empty_df_HENV117,
        empty_df_HENV103,
    ]
)
display(combined_ab)
site mutant wildtype effect value ab
0 71 A NaN escape_median NaN m102.4
1 71 C Q escape_median 0.290000 m102.4
2 71 D Q escape_median -0.021530 m102.4
3 71 E Q escape_median 0.017950 m102.4
4 71 F Q escape_median 0.004024 m102.4
... ... ... ... ... ... ...
13400 597 S I effect -3.190000 HENV-103
13401 597 T I effect -3.494000 HENV-103
13402 597 Y I effect -2.367000 HENV-103
13403 598 C P effect -2.125000 HENV-103
13404 598 W P effect -3.169000 HENV-103

82160 rows × 6 columns

In [24]:
def plot_distance_only(df, trigger):
    custom_order = [
        "distance",
        "R",
        "K",
        "H",
        "D",
        "E",
        "Q",
        "N",
        "S",
        "T",
        "Y",
        "W",
        "F",
        "A",
        "I",
        "L",
        "M",
        "V",
        "G",
        "P",
        "C",
    ]
    all_residues = range(71, 603)
    final_df = df
    final_df = final_df.sort_values(
        "site"
    )  # Sort the dataframe by 'site' to ensure that duplicates are detected correctly.
    sort_order = {
        mutant: i for i, mutant in enumerate(custom_order)
    }  # Create a dictionary that maps each mutant to its sort rank based on the custom order
    final_df["mutant_rank"] = final_df["mutant"].map(
        sort_order
    )  # Map the 'mutant' column to these ranks

    final_df = final_df.sort_values(
        "mutant_rank"
    )  # Now sort the dataframe by this rank
    final_df = final_df.drop(
        columns=["mutant_rank"]
    )  # Drop the 'mutant_rank' column as it is no longer needed after sorting
    sites = sorted(final_df["site"].unique(), key=lambda x: float(x))
    ab_list = ["m102.4", "HENV-26", "HENV-117", "HENV-103", "HENV-32", "nAH1.3"]
    empty_chart = []  # setup collection for charts
    for idx, ab in enumerate(ab_list):
        tmp_df = final_df[final_df["ab"] == ab]
        if ab == "m102.4":
            site_subset = m102_combined_sites
            # legend_conditional = alt.Legend(title='Distance to mAb')
        if ab == "HENV-26":
            site_subset = HENV26_combined_sites
            # legend_conditional = alt.Legend(title='Distance to mAb')
        if ab == "HENV-32":
            site_subset = HENV32_combined_sites
            # legend_conditional = alt.Legend(title='Distance to mAb')
        if ab == "HENV-103":
            site_subset = sites_dict["HENV-103"]
            # legend_conditional = alt.Legend(title=None)
        if ab == "HENV-117":
            site_subset = sites_dict["HENV-117"]
            # legend_conditional = alt.Legend(title=None)
        if ab == "nAH1.3":
            site_subset = nah_combined_sites
            # legend_conditional = alt.Legend(title='Distance to mAb')

        # select which sites you will show
        if trigger == True:
            tmp_df = tmp_df[tmp_df["site"].isin(site_subset)]
            x_axis = alt.Axis(
                labelAngle=-90,
                # labelExpr="datum.value % 10 === 0 ? datum.value : ''",
                title="Site",
            )
        else:
            tmp_df = tmp_df[tmp_df["site"].isin(all_residues)]

            # Conditionally set the x-axis labels and title for the last plot
            is_last_plot = idx == len(ab_list) - 1
            x_axis = alt.Axis(
                labelAngle=-90,
                labelExpr="datum.value % 10 === 0 ? datum.value : ''",
                title="Site" if is_last_plot else None,
                labels=True,
            )  # Only show labels for the last plot

        # Prepare the color scales separately for distance and effects
        # Filter out 'distance' values before creating the effect heatmap
        effect_df = tmp_df[
            (tmp_df["mutant"] != "distance") & (tmp_df["effect"] != "effect")
        ]
        max_color = effect_df["value"].max()
        min_color = effect_df["value"].min()

        # Adjust color scheme for abs with little sensitizing mutations
        if min_color > -1:
            min_color = min_color - 1

        # Prepare the color scale for effects, Altair will automatically determine the domain
        color_scale_escape = alt.Scale(
            scheme="redblue", domainMid=0, domain=[min_color, max_color]
        )
        color_scale_entropy = alt.Scale(scheme="greens", domain=[0, 15], reverse=True)

        strokewidth_size = 0.25

        unique_wildtypes_df = tmp_df.drop_duplicates(subset=["site", "wildtype"])

        # The chart for the heatmap
        base = (
            alt.Chart(tmp_df, title=f"{ab}")
            .encode(
                x=alt.X("site:O", title="Site", sort=sites, axis=x_axis),
                y=alt.Y(
                    "mutant",
                    title="Amino Acid",
                    sort=alt.EncodingSortField(field="sort_order", order="ascending"),
                    axis=alt.Axis(grid=False),
                ),  # Apply custom sort order here
                tooltip=["site", "wildtype", "mutant", "value"],
            )
            .properties(width=alt.Step(10), height=alt.Step(11))
        )
        # Heatmap for distance
        chart_empty = (
            base.mark_rect(color="#e6e7e8")
            .encode()
            .transform_filter(alt.datum.effect == "escape_median")
        )
        # Heatmap for effect
        chart_effect = (
            base.mark_rect(stroke="black", strokeWidth=strokewidth_size)
            .encode(
                color=alt.condition(
                    'datum.mutant != "distance"',
                    alt.Color(
                        "value:Q",
                        scale=color_scale_escape,
                        legend=alt.Legend(title=f"{ab} Escape"),
                    ),
                    alt.value("transparent"),
                ),
            )
            .transform_filter(alt.datum.effect == "escape_median")
        )

        # Heatmap for distance
        if ab in ["m102.4", "HENV-26", "HENV-32", "nAH1.3"]:
            chart_distance = (
                base.mark_rect()
                .encode(
                    color=alt.condition(
                        'datum.mutant == "distance"',
                        alt.Color(
                            "value:Q",
                            scale=color_scale_entropy,
                            legend=alt.Legend(title="Distance to mAb"),
                        ),
                        alt.value("transparent"),
                    )
                )
                .transform_filter(alt.datum.effect == "escape_median")
            )
        else:
            chart_distance = (
                base.mark_rect(color="transparent")
                .encode(
                    # color=alt.Color('white'),
                    # alt.Color('value:Q', scale=color_scale_entropy,legend=alt.Legend(title='Distance to mAb')),
                    # alt.value('transparent'))
                )
                .transform_filter(alt.datum.effect == "escape_median")
            )
        # Heatmap for distance
        chart_filtered = (
            base.mark_rect(
                color="#939598", stroke="black", strokeWidth=strokewidth_size
            )
            .encode()
            .transform_filter(alt.datum.effect == "effect")
        )

        # The layer for the wildtype boxes
        wildtype_layer_box = (
            alt.Chart(unique_wildtypes_df)
            .mark_rect(color="white", stroke="black", strokeWidth=strokewidth_size)
            .encode(
                x=alt.X("site:O", sort=sites),
                y=alt.Y(
                    "wildtype",
                    sort=alt.EncodingSortField(field="sort_order", order="ascending"),
                ),
                opacity=alt.value(1),
            )
            .transform_filter(
                (alt.datum.wildtype != "")
                & (alt.datum.wildtype != None)
                & (alt.datum.value != None)
            )
        )
        # The layer for the wildtype amino acids
        wildtype_layer = (
            alt.Chart(unique_wildtypes_df)
            .mark_text(color="black", text="X", size=8)
            .encode(
                x=alt.X("site:O", sort=sites),
                y=alt.Y(
                    "wildtype",
                    sort=alt.EncodingSortField(field="sort_order", order="ascending"),
                ),
                opacity=alt.value(1),
            )
            .transform_filter(
                (alt.datum.wildtype != "")
                & (alt.datum.wildtype != None)
                & (alt.datum.value != None)
            )
        )

        # Combine the heatmap layer with the wildtype layer
        chart = alt.layer(
            chart_empty,
            chart_effect,
            chart_distance,
            chart_filtered,
            wildtype_layer_box,
            wildtype_layer,
        ).resolve_scale(color="independent")
        empty_chart.append(chart)
    combined_chart = (
        alt.vconcat(*empty_chart, spacing=1)
        .resolve_scale(y="shared", x="independent", color="independent")
        .configure_title(
            anchor="start",  # Aligns the title to the left ('middle' for center, 'end' for right)
            offset=10,  # Adjusts the distance of the title from the chart
            orient="top",  # Positions the title at the top; use 'bottom' to position at the bottom
        )
    )
    return combined_chart


mab_plot = plot_distance_only(combined_ab, True)
mab_plot.display()
if mab_line_escape_plot is not None:
    mab_plot.save(mab_plot_top)

Make full antibody escape heatmaps¶

In [25]:
mab_all = plot_distance_only(combined_ab, False)
mab_all.display()
if mab_line_escape_plot is not None:
    mab_all.save(mab_plot_all)

Now make heatmaps of antibody escape versus Ephrin Binding¶

First prepare data:

In [26]:
bind_df = pd.read_csv(binding_data)
binding_df = bind_df.groupby("site")["binding_median"].median().reset_index()


def make_empty_binding():
    sites = range(71, 603)
    data = [{"site": site} for site in sites]
    empty_df = pd.DataFrame(data)
    empty_df = pd.merge(empty_df, binding_df, on="site", how="left")
    empty_df = empty_df.rename(columns={"binding_median": "value"})
    empty_df["effect"] = "escape_median"
    empty_df["ab"] = "Ephrin-B2 binding"
    return empty_df


binding_empty = make_empty_binding()

escape_df = combined_df.groupby(["ab", "site"])["escape_median"].median().reset_index()


def make_empty_df(ab):
    sites = range(71, 603)
    data = [{"site": site} for site in sites]

    # Create the DataFrame
    empty_df = pd.DataFrame(data)

    all_sites_df = pd.merge(
        empty_df, escape_df.query(f'ab == "{ab}"'), on=["site"], how="left"
    )

    df_melted = all_sites_df.melt(
        id_vars=["site"],
        value_vars=["escape_median"],
        var_name="effect",
        value_name="value",
    )

    df_test = pd.concat([df_melted], ignore_index=True)
    df_test["ab"] = ab
    return df_test


ab_list = ["m102.4", "HENV-26", "HENV-117", "HENV-103", "HENV-32", "nAH1.3"]
# ab_list = ['HENV-32']

empty = []
for ab in ab_list:
    tmp_df = make_empty_df(ab)
    empty.append(tmp_df)
all_empties_df = pd.concat(empty, ignore_index=True)
all_empties_df = pd.concat([all_empties_df, binding_empty])
display(all_empties_df)
site effect value ab
0 71 escape_median 0.011995 m102.4
1 72 escape_median 0.037640 m102.4
2 73 escape_median 0.042650 m102.4
3 74 escape_median -0.039695 m102.4
4 75 escape_median -0.017860 m102.4
... ... ... ... ...
527 598 escape_median 0.158000 Ephrin-B2 binding
528 599 escape_median -0.214000 Ephrin-B2 binding
529 600 escape_median -0.037000 Ephrin-B2 binding
530 601 escape_median 0.240000 Ephrin-B2 binding
531 602 escape_median -0.053500 Ephrin-B2 binding

3724 rows × 4 columns

In [27]:
def make_heatmap_with_binding(df):
    # Define the custom sort order directly in the encoding
    sort_order = [
        "NiV Polymorphism",
        "Ephrin-B2 binding",
        "m102.4",
        "HENV-26",
        "HENV-117",
        "HENV-103",
        "HENV-32",
        "nAH1.3",
    ]
    full_ranges = [
        list(range(start, end))
        for start, end in [(71, 181), (181, 291), (291, 401), (401, 511), (511, 603)]
    ]

    # container to hold the charts
    charts = []
    color_scale_effect = alt.Scale(scheme="redblue", domainMid=0)
    color_scale_binding = alt.Scale(scheme="redblue", domainMid=0)

    for idx, subset in enumerate(full_ranges):
        subset_df = df[df["site"].isin(subset)]  # for the wrapping of sites
        is_last_plot = idx == len(full_ranges) - 1
        x_axis = alt.Axis(
            labelAngle=-90,
            labelExpr="datum.value % 10 === 0 ? datum.value : ''",
            title="Site" if is_last_plot else None,
            labels=True,
        )  # Only show labels for the last plot

        effect_legend = (
            alt.Legend(title="Antibody Escape") if is_last_plot else None
        )  # ,direction='horizontal',gradientLength=50,titleAnchor='middle',tickCount=3,labelAlign='center')
        binding_legend = (
            alt.Legend(title="Henipavirus Entropy") if is_last_plot else None
        )  # ,direction='horizontal',gradientLength=50,titleAnchor='middle',labelAlign='center')
        print(is_last_plot)
        print(effect_legend)
        base = (
            alt.Chart(subset_df)
            .encode(
                x=alt.X("site:O", title="Site", axis=x_axis),
                y=alt.Y(
                    "ab", title=None, sort=sort_order, axis=alt.Axis(grid=False)
                ),  # Correctly apply custom sort order
                tooltip=["site", "value"],
            )
            .properties(width=alt.Step(10), height=alt.Step(11))
        )

        # Define the chart for empty cells
        chart_empty = base.mark_rect(color="#e6e7e8").transform_filter(
            alt.datum.effect == "escape_median"
        )

        # Define the chart for cells with effect
        chart_effect = (
            base.mark_rect(stroke="black", strokeWidth=0.25)
            .encode(
                color=alt.condition(
                    'datum.effect == "escape_median"',
                    alt.Color(
                        "value:Q", scale=color_scale_effect, legend=effect_legend
                    ),  # Define a color scale
                    alt.value("transparent"),
                )
            )
            .transform_filter(alt.datum.effect == "escape_median")
        )

        chart_binding = (
            base.mark_rect(strokeWidth=1.1)
            .encode(
                stroke=alt.value("value"),
                color=alt.condition(
                    'datum.effect == "escape_median"',
                    alt.Color(
                        "value:Q", scale=color_scale_binding, legend=binding_legend
                    ),
                    alt.value("transparent"),
                ),
            )
            .transform_filter(alt.datum.ab == "Ephrin-B2 binding")
        )

        chart_poly = (
            base.mark_rect(color="black")
            .encode()
            .transform_filter(alt.datum.ab == "NiV Polymorphism")
        )
        # Layer the charts using `layer` instead of `+`
        chart = alt.layer(
            chart_empty, chart_effect, chart_binding, chart_poly
        )  # .resolve_scale(color='shared')
        charts.append(chart)
    combined_chart = alt.vconcat(
        *charts, spacing=5, title="Heatmap of median mAb escape and Ephrin-B2 binding"
    )  

    return combined_chart


# Assuming `all_empties_df` is your DataFrame and already defined
chart = make_heatmap_with_binding(all_empties_df)
chart.display()
if mab_line_escape_plot is not None:
    chart.save(aggregate_mab_and_binding)
False
None
False
None
False
None
False
None
True
Legend({
  title: 'Antibody Escape'
})

Now show heatmap with nipah polymorphisms¶

In [28]:
def make_contact():
    df = pd.DataFrame({"site": niv_poly, "contact": [0.0] * len(niv_poly)})
    df = df[["site", "contact"]]
    # df['mutant'] = 'contact'
    df["ab"] = "NiV Polymorphism"
    df["effect"] = "median_escape"
    df.rename(columns={"contact": "value"}, inplace=True)
    return df


niv_poly = config['nipah_poly']
contact_df = make_contact()

bind_df = pd.read_csv("results/filtered_data/E2_binding_filtered.csv")
binding_df = bind_df.groupby("site")["binding_median"].max().reset_index()


def make_empty_binding():
    sites = range(71, 603)
    data = [{"site": site} for site in sites]
    empty_df = pd.DataFrame(data)
    empty_df = pd.merge(empty_df, binding_df, on="site", how="left")
    empty_df = empty_df.rename(columns={"binding_median": "value"})
    empty_df["effect"] = "escape_median"
    empty_df["ab"] = "Ephrin-B2 binding"
    return empty_df


binding_empty = make_empty_binding()

escape_df = combined_df.groupby(["ab", "site"])["escape_median"].max().reset_index()


def make_empty_df(ab):
    sites = range(71, 603)
    data = [{"site": site} for site in sites]

    # Create the DataFrame
    empty_df = pd.DataFrame(data)

    all_sites_df = pd.merge(
        empty_df, escape_df.query(f'ab == "{ab}"'), on=["site"], how="left"
    )

    df_melted = all_sites_df.melt(
        id_vars=["site"],
        value_vars=["escape_median"],
        var_name="effect",
        value_name="value",
    )

    df_test = pd.concat([df_melted], ignore_index=True)
    df_test["ab"] = ab
    return df_test


ab_list = ["m102.4", "HENV-26", "HENV-117", "HENV-103", "HENV-32", "nAH1.3"]

empty = []
for ab in ab_list:
    tmp_df = make_empty_df(ab)
    empty.append(tmp_df)
all_empties_df = pd.concat(empty, ignore_index=True)
all_empties_df = pd.concat([all_empties_df, contact_df])
display(all_empties_df)
site effect value ab
0 71 escape_median 0.42320 m102.4
1 72 escape_median 0.09214 m102.4
2 73 escape_median 0.33640 m102.4
3 74 escape_median 0.06910 m102.4
4 75 escape_median 0.11410 m102.4
... ... ... ... ...
30 478 median_escape 0.00000 NiV Polymorphism
31 481 median_escape 0.00000 NiV Polymorphism
32 498 median_escape 0.00000 NiV Polymorphism
33 502 median_escape 0.00000 NiV Polymorphism
34 545 median_escape 0.00000 NiV Polymorphism

3227 rows × 4 columns

In [29]:
def make_heatmap_with_polymorphisms(df):
    # Define the custom sort order directly in the encoding
    sort_order = [
        "NiV Polymorphism",
        "m102.4",
        "HENV-26",
        "HENV-117",
        "HENV-103",
        "HENV-32",
        "nAH1.3",
    ]
    # full_ranges = [list(range(start, end)) for start, end in [(71, 204), (204, 337), (337, 470), (470, 603)]]
    full_ranges = [
        list(range(start, end))
        for start, end in [(71, 181), (181, 291), (291, 401), (401, 511), (511, 603)]
    ]

    # container to hold the charts
    charts = []
    color_scale_effect = alt.Scale(scheme="redblue", domainMid=0, domain=[0, 2])
    color_scale_binding = alt.Scale(scheme="redblue", domainMid=0, domain=[-5, 2])

    # Flags for showing the legend only the first time
    effect_legend_added = True
    binding_legend_added = True
    for idx, subset in enumerate(full_ranges):
        subset_df = df[df["site"].isin(subset)]  # for the wrapping of sites
        is_last_plot = idx == len(full_ranges) - 1
        x_axis = alt.Axis(
            labelAngle=-90,
            labelExpr="datum.value % 10 === 0 ? datum.value : ''",
            title="Site" if is_last_plot else None,
            labels=True,
        )  # Only show labels for the last plot

        base = (
            alt.Chart(subset_df)
            .encode(
                x=alt.X("site:O", title="Site", axis=x_axis),
                y=alt.Y(
                    "ab", title=None, sort=sort_order, axis=alt.Axis(grid=False)
                ),  # Correctly apply custom sort order
                tooltip=["site", alt.Tooltip("value", format=".2f")],
            )
            .properties(width=alt.Step(10), height=alt.Step(11))
        )

        # Define the chart for empty cells
        chart_empty = base.mark_rect(color="#e6e7e8").transform_filter(
            alt.datum.effect == "escape_median"
        )
        if not effect_legend_added:
            # Define the chart for cells with effect
            chart_effect = (
                base.mark_rect(stroke="black", strokeWidth=0.25)
                .encode(
                    color=alt.condition(
                        'datum.effect == "escape_median"',
                        alt.Color(
                            "value:Q", scale=color_scale_effect
                        ),  # Define a color scale
                        alt.value("transparent"),
                    )
                )
                .transform_filter(alt.datum.effect == "escape_median")
            )
            effect_legend_added = True
        else:
            # Define the chart for cells with effect
            chart_effect = (
                base.mark_rect(stroke="black", strokeWidth=0.25)
                .encode(
                    color=alt.condition(
                        'datum.effect == "escape_median"',
                        alt.Color(
                            "value:Q", scale=color_scale_effect, legend=None
                        ),  # Define a color scale
                        alt.value("transparent"),
                    )
                )
                .transform_filter(alt.datum.effect == "escape_median")
            )
        if not binding_legend_added:
            chart_binding = (
                base.mark_rect(strokeWidth=1.1)
                .encode(
                    stroke=alt.value("value"),
                    color=alt.condition(
                        'datum.effect == "escape_median"',
                        alt.Color("value:Q", scale=color_scale_binding),
                        alt.value("transparent"),
                    ),
                )
                .transform_filter(alt.datum.ab == "Ephrin-B2 binding")
            )
            binding_legend_added = True
        else:
            chart_binding = (
                base.mark_rect(strokeWidth=1.1)
                .encode(
                    stroke=alt.value("value"),
                    color=alt.condition(
                        'datum.effect == "escape_median"',
                        alt.Color("value:Q", scale=color_scale_binding, legend=None),
                        alt.value("transparent"),
                    ),
                )
                .transform_filter(alt.datum.ab == "Ephrin-B2 binding")
            )

        chart_poly = (
            base.mark_rect(color="black")
            .encode()
            .transform_filter(alt.datum.ab == "NiV Polymorphism")
        )
        # Layer the charts using `layer` instead of `+`
        chart = alt.layer(chart_empty, chart_effect, chart_poly).resolve_scale(
            color="independent"
        )
        charts.append(chart)
    combined_chart = alt.vconcat(
        *charts, spacing=5, title="Heatmap of max mAb escape and Nipah Polymorphisms"
    ).resolve_scale(y="shared", x="independent", color="shared")

    return combined_chart


# Assuming `all_empties_df` is your DataFrame and already defined
chart = make_heatmap_with_polymorphisms(all_empties_df)
chart.display()
if mab_line_escape_plot is not None:
    chart.save(aggregate_mab_and_niv_polymorphism)
In [30]:
def make_polymorphism_escape(df):
    df['is_poly'] = df['site'].isin(config['nipah_poly'])
    aggregated_df = df.groupby(['site', 'ab']).agg({
        'escape_median': 'sum',
        'wildtype': 'first',
        'is_poly': 'first'
    }).reset_index()
    aggregated_df['is_poly'] = aggregated_df['is_poly'].map({True: 'Polymorphic', False: 'Conserved'})
    empty_chart = []
    antibodies = ['m102.4', 'HENV-26', 'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']
    for index, ab in enumerate(antibodies):
        tmp_df = aggregated_df[aggregated_df['ab'] == ab]
        # Conditionally set the y-axis title for only the first chart
        y_axis_title = 'Summed Escape' if index == 0 else None
        
        base = alt.Chart(tmp_df).mark_point(size=25, opacity=0.2, filled=True,strokeWidth=0.5,stroke='black').encode(
            x=alt.X("is_poly", title=None, axis=alt.Axis(labelAngle=-90, grid=False)),
            xOffset="random:Q",
            y=alt.Y("escape_median", title=y_axis_title, axis=alt.Axis(grid=True, tickCount=2)),  
            tooltip=["site"],
            color=alt.Color('is_poly',legend=None),
        ).transform_calculate(
            random="sqrt(-2*log(random()))*cos(2*PI*random())"
        ).properties(
            title={
                "text": f'{ab}',
                "fontSize": 10,  # Adjust font size as needed
                "align": "center",
                'anchor': "middle"
            },
            width=50,
        )
        empty_chart.append(base)
    combined_chart = alt.hconcat(*empty_chart, spacing=10).resolve_scale(y="independent", color="shared")

    return combined_chart





chart = make_polymorphism_escape(combined_df)
chart.display()
In [31]:
def make_polymorphism_escape(df):
    df['is_poly'] = df['site'].isin(niv_poly)
    aggregated_df = df.groupby(['site']).agg({
        'escape_median': 'median',
        'wildtype': 'first',
        'is_poly': 'first'
    }).reset_index()
    empty_chart = []
    #antibodies = ['m102.4', 'HENV-26', 'HENV-117', 'HENV-103', 'HENV-32', 'nAH1.3']
    #for index, ab in enumerate(antibodies):
        #tmp_df = aggregated_df[aggregated_df['ab'] == ab]
        # Conditionally set the y-axis title for only the first chart
        #y_axis_title = 'Summed Escape' if index == 0 else None
        
    base = alt.Chart(aggregated_df).mark_point(size=25, opacity=0.2, filled=True,strokeWidth=0.5,stroke='black').encode(
        x=alt.X("is_poly", title=None, axis=alt.Axis(labelAngle=-90, grid=False)),
        xOffset="random:Q",
        y=alt.Y("escape_median", title=None, axis=alt.Axis(grid=True, tickCount=5)),  
        tooltip=["site"],
        color=alt.Color('is_poly',title='Polymorphic Site'),
    ).transform_calculate(
        random="sqrt(-2*log(random()))*cos(2*PI*random())"
    ).properties(
        width=50,
    )
#        empty_chart.append(base)
#    combined_chart = alt.hconcat(*empty_chart, spacing=10).resolve_scale(y="independent", color="shared")

    return base





chart = make_polymorphism_escape(combined_df)
chart.display()

Make plots comparing escape with binding to see if escape sites do so by increasing binding¶

In [32]:
new_merged_df = pd.merge(
    combined_df,
    bind_df[["site", "wildtype", "mutant", "binding_median"]],
    on=["site", "wildtype", "mutant"],
    how="left",
)
new_merged_df = new_merged_df.drop(
    columns=[
        "mutation",
        "escape_std",
        "times_seen_ab",
        "show_site",
        "wt_codon",
        "closest_mutant_codon",
        "min_mutations",
    ]
)
new_merged_df = new_merged_df.round(2)

ab_list1 = ["m102.4", "HENV-26", "HENV-117"]
ab_list2 = ["HENV-103", "HENV-32"]
ab_list3 = ["nAH1.3"]


def plot_escape_vs_binding(df):
    variant_selector = alt.selection_point(
        on="mouseover", empty=False, nearest=True, fields=["site"], value=1
    )
    empty_chart1 = []
    for ab in ab_list1:
        tmp_df = df[df["ab"] == ab]
        base = (
            alt.Chart(tmp_df, title=alt.Title(f"{ab}", anchor="middle"))
            .mark_point(
                filled=True, size=15, color="#1f4e79", opacity=0.15, stroke="black"
            )
            .encode(
                alt.X(
                    "binding_median",
                    title="EFNB2 Binding",
                    axis=alt.Axis(grid=True, tickCount=3),
                ),
                alt.Y(
                    "escape_median",
                    title="Antibody Escape",
                    axis=alt.Axis(grid=True, tickCount=3),
                ),
                tooltip=[
                    "site",
                    "wildtype",
                    "mutant",
                    "escape_median",
                    "binding_median",
                ],
                opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.2)),
                size=alt.condition(variant_selector, alt.value(50), alt.value(20)),
                strokeWidth=alt.condition(variant_selector, alt.value(2), alt.value(0)),
            )
        )
        empty_chart1.append(base)
    combined_chart1 = alt.hconcat(*empty_chart1, spacing=5).resolve_scale(
        x="shared", y="shared"
    )
    empty_chart2 = []
    for ab in ab_list2:
        tmp_df = df[df["ab"] == ab]
        base = (
            alt.Chart(tmp_df, title=alt.Title(f"{ab}", anchor="middle"))
            .mark_point(
                filled=True, size=15, color="#ff7f0e", opacity=0.15, stroke="black"
            )
            .encode(
                alt.X(
                    "binding_median",
                    title="EFNB2 Binding",
                    axis=alt.Axis(grid=True, tickCount=3),
                ),
                alt.Y(
                    "escape_median",
                    title="Antibody Escape",
                    axis=alt.Axis(grid=True, tickCount=3),
                ),
                tooltip=[
                    "site",
                    "wildtype",
                    "mutant",
                    "escape_median",
                    "binding_median",
                ],
                opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.2)),
                size=alt.condition(variant_selector, alt.value(50), alt.value(20)),
                strokeWidth=alt.condition(variant_selector, alt.value(2), alt.value(0)),
            )
        )
        empty_chart2.append(base)
    combined_chart2 = alt.hconcat(*empty_chart2, spacing=5).resolve_scale(
        x="shared", y="shared"
    )

    empty_chart3 = []
    for ab in ab_list3:
        tmp_df = df[df["ab"] == ab]
        base3 = (
            alt.Chart(tmp_df, title=alt.Title(f"{ab}", anchor="middle"))
            .mark_point(
                filled=True, size=15, color="#2ca02c", opacity=0.15, stroke="black"
            )
            .encode(
                alt.X(
                    "binding_median",
                    title="EFNB2 Binding",
                    axis=alt.Axis(grid=True, tickCount=3),
                ),
                alt.Y(
                    "escape_median",
                    title="Antibody Escape",
                    axis=alt.Axis(grid=True, tickCount=3),
                ),
                tooltip=[
                    "site",
                    "wildtype",
                    "mutant",
                    "escape_median",
                    "binding_median",
                ],
                opacity=alt.condition(variant_selector, alt.value(1), alt.value(0.2)),
                size=alt.condition(variant_selector, alt.value(50), alt.value(20)),
                strokeWidth=alt.condition(variant_selector, alt.value(2), alt.value(0)),
            )
        )

    combined_chart_total = alt.vconcat(
        combined_chart1,
        combined_chart2,
        base3,
        title=alt.Title(
            "Antibody Escape versus Binding",
            subtitle="Colored by Epitope. Hover over points to see the same sites",
        ),
    ).add_params(
        variant_selector
    )  
    return combined_chart_total


tmp_img_test = plot_escape_vs_binding(new_merged_df)
tmp_img_test.display()
if mab_line_escape_plot is not None:
    tmp_img_test.save(binding_vs_escape)
In [ ]: